import pandas as pd
import numpy as np
from transformers import pipeline, logging
import torch
import warnings
# Silence noisy library warning categories so the progress output stays readable.
for _noisy_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_noisy_category)
del _noisy_category
# Opt in to pandas' future no-silent-downcasting behaviour and turn off
# chained-assignment warnings.
pd.set_option('future.no_silent_downcasting', True)
pd.set_option('mode.chained_assignment', None)
# Only let transformers report errors, not info/warning chatter.
logging.set_verbosity_error()

###########################
## Set device from CLI args or the DEVICE env var
###########################
import os
import argparse
def get_device():
    """
    Resolve the name of the device the script should run on.

    Precedence: the ``--device`` command-line flag, then the ``DEVICE``
    environment variable, then ``"cuda"`` (the caller falls back to CPU when
    CUDA is not available). Values are stripped and lower-cased, and an empty
    or whitespace-only value is treated as unset. Unrecognised command-line
    arguments are ignored so the script still runs when launched with extra
    flags.

    Returns
    -------
    str
        The requested device name, e.g. ``"cpu"``, ``"mps"`` or ``"cuda"``.
        It is not validated here; the caller rejects unknown names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    args, _ = parser.parse_known_args()  # ignore unknown args so script still works normally
    # Priority: CLI argument > DEVICE environment variable > default="cuda".
    for candidate in (args.device, os.getenv("DEVICE")):
        if candidate and candidate.strip():
            return candidate.strip().lower()
    return "cuda"
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch
# Each supported backend name maps to a zero-argument availability probe.
# The probes are lambdas so a backend's attributes are only touched when
# that backend is the one requested.
_BACKEND_PROBES = {
    "cpu": lambda: True,
    "mps": lambda: torch.backends.mps.is_available(),
    "cuda": lambda: torch.cuda.is_available(),
}
if DEVICE not in _BACKEND_PROBES:
    raise ValueError(f"Unknown device {DEVICE}")
# Use the requested backend when this machine supports it, otherwise CPU.
device = torch.device(DEVICE) if _BACKEND_PROBES[DEVICE]() else torch.device("cpu")
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################
# Load data
test = pd.read_csv('./data/polnli_test_results.csv')

# Build the pipeline inputs: a list (despite the name) with one
# {'text': premise, 'text_pair': hypothesis} dict per row, in row order.
# Zipping the two columns avoids a label-based .loc lookup for every row.
docs_dict = [{'text': premise, 'text_pair': hypothesis}
             for premise, hypothesis in zip(test['premise'], test['augmented_hypothesis'])]

def label_docs(model, docs_dict, batch_size=32, device=device):
    """
    Classify premise/hypothesis pairs with a text-classification pipeline.

    Parameters
    ----------
    model :
        Model name or path handed to ``transformers.pipeline``.
    docs_dict : list of dict
        Pipeline inputs, each with ``'text'`` and ``'text_pair'`` keys.
    batch_size : int
        Number of inputs per forward pass.
    device :
        Device the pipeline runs on; defaults to the module-level ``device``
        resolved at definition time.

    Returns
    -------
    list
        One predicted label per input. The label strings come from the
        model's config; the caller expects 'entailment' / 'not_entailment'.
        Inputs are truncated to 512 tokens and weights are loaded in bfloat16.
    """
    classifier = pipeline(
        task='text-classification',
        model=model,
        batch_size=batch_size,
        device=device,
        max_length=512,
        truncation=True,
        torch_dtype=torch.bfloat16,
    )
    predictions = classifier(docs_dict)
    return [prediction['label'] for prediction in predictions]

# Hugging Face checkpoints to evaluate, each paired with the column of `test`
# that receives its predictions. Insertion order fixes the evaluation order.
_MODEL_COLUMNS = {
    "mlburnham/Political_DEBATE_ModernBERT_base_v1.0": 'base_modern',
    "mlburnham/Political_DEBATE_ModernBERT_large_v1.0": 'large_modern',
}
models = list(_MODEL_COLUMNS)
columns = list(_MODEL_COLUMNS.values())

# Integer code written to the CSV for each pipeline label.
# NOTE(review): 0 = entailment / 1 = not_entailment is kept from the original
# mapping; confirm it matches the encoding of the gold labels in `test`.
_LABEL_CODES = {'entailment': 0, 'not_entailment': 1}

# zip() would silently drop unmatched entries, so require a 1:1 pairing.
if len(models) != len(columns):
    raise ValueError(f"{len(models)} models but {len(columns)} result columns")

# For each model, classify the document pairs and store its integer labels in
# the matching column of `test`.
for modname, col in zip(models, columns):
    labels = pd.Series(label_docs(modname, docs_dict, device = device), index = test.index)
    codes = labels.map(_LABEL_CODES)
    # Fail loudly instead of writing raw label strings next to 0/1 codes
    # when a model's label names differ from the mapping above.
    unexpected = sorted(set(labels[codes.isna()]))
    if unexpected:
        raise ValueError(f"{modname} returned unmapped labels: {unexpected}")
    test[col] = codes.astype(int)
    print(modname + ' complete.')

# Write the predictions back to the results file read at the top of the data
# section ('polnli', not the previous 'plonli' typo), so the new model columns
# are saved alongside the existing ones. This overwrites that file in place.
test.to_csv('./data/polnli_test_results.csv', index = False)